import cv2
import os
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

def image_inpainting(input_location, input_mask_location):
    # 从文件读取输入图像和蒙版
    input = {
        'img': input_location,
        'mask': input_mask_location,
    }
    
    # 加载模型进行图像修复
    inpainting = pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama')
    result = inpainting(input)
    vis_img = result[OutputKeys.OUTPUT_IMG]

    # 生成输出图像保存路径
    output_path = os.path.join('result', os.path.basename(input_location))
    
    # 保存修复后的图像
    cv2.imwrite(output_path, vis_img)
    
    return result

if __name__ == '__main__':
    # 源文件目录
    input_folder = '2k/input'
    # 蒙版图片目录
    mask_folder = '2k/mask'

    # 遍历输入文件夹中的所有图像和蒙版，并进行图像修复
    for input_file in os.listdir(input_folder):
        if input_file.endswith('.png'):
            # 构造蒙版文件名
            mask_file = input_file
            # mask_file = 'mask_' + input_file
            # mask_file = mask_file.replace('.JPEG', '.png')
            # 这里是
            mask_file = input_file
            
            # 构造输入文件和蒙版文件的完整路径
            input_path = os.path.join(input_folder, input_file)
            mask_path = os.path.join(mask_folder, mask_file)
            
            # 进行图像修复，并将修复后的图像保存到result文件夹
            image_inpainting(input_path, mask_path)